{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Complex-Valued Convolutions for Modulation Recognition using Deep Learning\n",
    "- Author: Jakob Krzyston\n",
    "- Date: 1/27/2020\n",
    "\n",
    "This code is based on: https://github.com/radioML/examples/blob/master/modulation_recognition/RML2016.10a_VTCNN2_example.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os,random\n",
    "import numpy as np\n",
    "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
    "from keras.utils import np_utils\n",
    "import keras.models as models\n",
    "from keras import backend as K\n",
    "from keras.layers.core import Reshape,Dense,Dropout,Activation,Flatten,Lambda,Permute\n",
    "from keras.layers.noise import GaussianNoise\n",
    "from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D\n",
    "from keras.regularizers import *\n",
    "from keras.optimizers import adam\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import keras\n",
    "import pickle, random, time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the dataset\n",
    "- data was downloaded from https://www.deepsig.io/datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset dictionary; its keys are (modulation, snr) tuples.\n",
    "# NOTE(review): pickle.load can run arbitrary code while unpickling - only load\n",
    "# a dataset file obtained from a source you trust.\n",
    "with open(\"RML2016.10a_dict.pkl\", 'rb') as dataset_file:\n",
    "    Xd = pickle.load(dataset_file, encoding = 'latin1')\n",
    "\n",
    "# Sorted unique modulation names (key element 0) and SNR values (key element 1)\n",
    "mods = sorted({key[0] for key in Xd.keys()})\n",
    "test_snrs = sorted({key[1] for key in Xd.keys()})\n",
    "\n",
    "# Stack every (mod, snr) group into one array and record, row for row, the\n",
    "# (mod, snr) pair it came from so that X[i] and lbl[i] stay aligned.\n",
    "X = []\n",
    "lbl = []\n",
    "for mod in mods:\n",
    "    for snr in test_snrs:\n",
    "        group = Xd[(mod,snr)]\n",
    "        X.append(group)\n",
    "        lbl.extend([(mod,snr)] * group.shape[0])\n",
    "X = np.vstack(X)\n",
    "print(X.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Partition Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG so the random split below is repeatable.\n",
    "np.random.seed(2019)\n",
    "n_examples = X.shape[0]\n",
    "# Half of the examples (rounded) go to the training split.\n",
    "n_train    = int(round(n_examples * 0.5))\n",
    "# Training indices are drawn without replacement; every remaining index is a test index.\n",
    "train_idx  = np.random.choice(range(0,n_examples), size=n_train, replace=False)\n",
    "test_idx   = list(set(range(0,n_examples))-set(train_idx))\n",
    "X_train    = X[train_idx]\n",
    "X_test     = X[test_idx]\n",
    "\n",
    "def to_onehot(yy, num_classes=None):\n",
    "    \"\"\"Convert a sequence of integer class indices into a one-hot matrix.\n",
    "\n",
    "    Args:\n",
    "        yy: Sequence of non-negative integer class indices.\n",
    "        num_classes: Width of the encoding. Defaults to max(yy) + 1; pass it\n",
    "            explicitly so that a split which happens to miss the highest class\n",
    "            index still gets one column per class.\n",
    "\n",
    "    Returns:\n",
    "        Float array of shape (len(yy), num_classes) with a single 1 per row.\n",
    "    \"\"\"\n",
    "    yy = np.asarray(yy, dtype=int)\n",
    "    if num_classes is None:\n",
    "        # An empty input has no maximum; encode it with zero columns.\n",
    "        num_classes = int(yy.max()) + 1 if yy.size else 0\n",
    "    yy1 = np.zeros([len(yy), num_classes])\n",
    "    # Fancy indexing sets column yy[i] of row i to 1.\n",
    "    yy1[np.arange(len(yy)), yy] = 1\n",
    "    return yy1\n",
    "# One-hot targets: an example's class index is the position of its modulation\n",
    "# name (first element of its lbl entry) within mods.\n",
    "train_class_ids = [mods.index(lbl[idx][0]) for idx in train_idx]\n",
    "test_class_ids = [mods.index(lbl[idx][0]) for idx in test_idx]\n",
    "Y_train = to_onehot(train_class_ids)\n",
    "Y_test = to_onehot(test_class_ids)\n",
    "\n",
    "# Class names, and the per-example input shape (batch axis dropped) used to build the nets\n",
    "classes = mods\n",
    "in_shp = list(X_train.shape[1:])\n",
    "print(X_train.shape, in_shp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build the nets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CNN2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline CNN2: two convolution layers followed by two dense layers.\n",
    "dr = 0.5 # dropout rate: fraction of units dropped (0.5 = 50%), shared by all three nets\n",
    "cnn2 = models.Sequential()\n",
    "# Prepend a singleton channel axis; the conv layers below use channels_first.\n",
    "cnn2.add(Reshape([1]+in_shp, input_shape=in_shp))\n",
    "# Pad only the last (column) axis, by 2 on each side.\n",
    "cnn2.add(ZeroPadding2D((0, 2),data_format='channels_first'))\n",
    "cnn2.add(Convolution2D(256, (1, 3), padding='valid', activation=\"relu\", name=\"conv1\", kernel_initializer='glorot_uniform', data_format='channels_first'))#ch from 3->4\n",
    "cnn2.add(Dropout(dr))\n",
    "# The (2, 1) kernel spans two rows at a time.\n",
    "cnn2.add(Convolution2D(80, (2, 1), padding='valid', activation=\"relu\", name=\"conv2\", kernel_initializer='glorot_uniform', data_format='channels_first'))\n",
    "cnn2.add(Dropout(dr))\n",
    "cnn2.add(Flatten())\n",
    "cnn2.add(Dense(256, activation='relu', kernel_initializer='he_normal', name=\"dense1\"))\n",
    "cnn2.add(Dropout(dr))\n",
    "# One output unit per class, turned into probabilities by the softmax.\n",
    "cnn2.add(Dense(len(classes), kernel_initializer='he_normal', name=\"dense2\"))\n",
    "cnn2.add(Activation('softmax'))\n",
    "cnn2.add(Reshape([len(classes)]))\n",
    "cnn2.compile(loss='categorical_crossentropy', optimizer='adam')\n",
    "cnn2.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CNN2-260"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CNN2-260: the CNN2 layer stack with 260 units in the hidden dense layer.\n",
    "# Layers are created in the order they are added to the model.\n",
    "conv_options = dict(padding='valid', activation=\"relu\", kernel_initializer='glorot_uniform', data_format='channels_first')\n",
    "cnn2_260_layers = [\n",
    "    Reshape([1]+in_shp, input_shape=in_shp),\n",
    "    ZeroPadding2D((0, 2), data_format='channels_first'),\n",
    "    Convolution2D(256, (1, 3), **conv_options),\n",
    "    Dropout(dr),\n",
    "    Convolution2D(80, (2, 1), **conv_options),\n",
    "    Dropout(dr),\n",
    "    Flatten(),\n",
    "    Dense(260, activation='relu', kernel_initializer='he_normal'),\n",
    "    Dropout(dr),\n",
    "    Dense(len(classes), kernel_initializer='he_normal'),\n",
    "    Activation('softmax'),\n",
    "    Reshape([len(classes)]),\n",
    "]\n",
    "cnn2_260 = models.Sequential()\n",
    "for layer in cnn2_260_layers:\n",
    "    cnn2_260.add(layer)\n",
    "cnn2_260.compile(loss='categorical_crossentropy', optimizer='adam')\n",
    "cnn2_260.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Complex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the linear combination\n",
    "def LC(x):\n",
    "    \"\"\"Linearly combine the last axis of x with a fixed 2x3 matrix.\n",
    "\n",
    "    The matrix is [[0, 1, 0], [-1, 0, 1]], so a last axis holding (a, b, c)\n",
    "    becomes (b, c - a). The last axis of x must therefore have length 3.\n",
    "\n",
    "    Args:\n",
    "        x: Backend tensor whose last axis has length 3.\n",
    "\n",
    "    Returns:\n",
    "        Backend tensor with the same leading axes and a last axis of length 2.\n",
    "    \"\"\"\n",
    "    y = K.constant([[0, 1, 0], [-1, 0, 1]])\n",
    "    return K.dot(x, K.transpose(y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Complex-convolution network; assumes in_shp == [2, 128] (two rows per example) - TODO confirm.\n",
    "complex_CNN = models.Sequential()\n",
    "complex_CNN.add(Reshape([1]+in_shp, input_shape=in_shp))\n",
    "# Pad 1 row and 2 columns on every side so the (2, 3) kernel below also slides along the row axis.\n",
    "complex_CNN.add(ZeroPadding2D((1, 2),data_format='channels_first'))\n",
    "# conv1 is left linear: the ReLU is applied only after the LC combination further down.\n",
    "complex_CNN.add(Convolution2D(256, (2, 3), padding='valid', activation='linear', name=\"conv1\", kernel_initializer='glorot_uniform', data_format='channels_first'))#ch from 3->4\n",
    "# Swap the last two axes so that LC's K.dot acts on the row axis, then swap them back.\n",
    "complex_CNN.add(Permute((1,3,2)))\n",
    "complex_CNN.add(Lambda(LC))\n",
    "complex_CNN.add(Permute((1,3,2)))\n",
    "complex_CNN.add(Activation('relu'))\n",
    "complex_CNN.add(Dropout(dr))\n",
    "complex_CNN.add(Convolution2D(80, (2, 3), padding='valid', activation=\"relu\", name=\"conv2\", kernel_initializer='glorot_uniform', data_format='channels_first'))\n",
    "complex_CNN.add(Dropout(dr))\n",
    "complex_CNN.add(Flatten())\n",
    "complex_CNN.add(Dense(256, activation='relu', kernel_initializer='he_normal', name=\"dense1\"))\n",
    "complex_CNN.add(Dropout(dr))\n",
    "# One output unit per class, turned into probabilities by the softmax.\n",
    "complex_CNN.add(Dense( len(classes), kernel_initializer='he_normal', name=\"dense2\" ))\n",
    "complex_CNN.add(Activation('softmax'))\n",
    "complex_CNN.add(Reshape([len(classes)]))\n",
    "complex_CNN.compile(loss='categorical_crossentropy', optimizer='adam')\n",
    "complex_CNN.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parameterize the Training Process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maximum number of epochs per network (the EarlyStopping callback used during training may stop sooner)\n",
    "epochs = 100\n",
    "# Batch size used for training and for the overall test-set predictions\n",
    "batch_size = 1024  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train the networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_time(model, model_name, filepath):\n",
    "    \"\"\"Fit model on the training split, restore its best checkpoint and report the time.\n",
    "\n",
    "    Uses the notebook-level X_train, Y_train, X_test, Y_test, batch_size and epochs.\n",
    "\n",
    "    Args:\n",
    "        model: Compiled Keras model to train.\n",
    "        model_name: Name printed in the timing message.\n",
    "        filepath: File the best-so-far weights are checkpointed to and reloaded from.\n",
    "\n",
    "    Returns:\n",
    "        The History object returned by model.fit.\n",
    "    \"\"\"\n",
    "    start = time.time()\n",
    "    history = model.fit(X_train,\n",
    "        Y_train,\n",
    "        batch_size=batch_size,\n",
    "        epochs=epochs,\n",
    "        verbose=2,\n",
    "        # NOTE(review): the test split doubles as validation data, so early stopping\n",
    "        # and checkpoint selection are tuned on the data that is reported on later.\n",
    "        validation_data=(X_test, Y_test),\n",
    "        # NOTE(review): 'auto' is kept from the original call; Keras documents\n",
    "        # class_weight as a dict of class index -> weight. Confirm this string is\n",
    "        # accepted (and has any effect) in the Keras version in use.\n",
    "        class_weight='auto',\n",
    "        callbacks = [\n",
    "            keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=True, mode='auto'),\n",
    "            keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, verbose=0, mode='auto')\n",
    "        ])\n",
    "    # Go back to the checkpoint with the lowest validation loss\n",
    "    model.load_weights(filepath)\n",
    "    duration = time.time() - start\n",
    "    print(model_name + ' Training time = ' + str(round(duration/60,5)) + ' minutes')\n",
    "    return history\n",
    "\n",
    "history_cnn2 = train_and_time(cnn2, 'CNN2', 'cnn2.wts.h5')\n",
    "history_cnn2_260 = train_and_time(cnn2_260, 'CNN2-260', 'cnn2_260.wts.h5')\n",
    "history_complex = train_and_time(complex_CNN, 'Complex', 'complex.wts.h5')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explore Results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loss Curve Plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CNN2: training and validation loss per epoch, saved to train_cnn2.pdf.\n",
    "# NOTE(review): the curves are the loss values recorded by fit(), although the\n",
    "# axis and legend text call them 'Error'.\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "curves = history_cnn2.history\n",
    "ax.plot(history_cnn2.epoch, curves['loss'], label=\"CNN2: Training Error + Loss\")\n",
    "ax.plot(history_cnn2.epoch, curves['val_loss'], label=\"CNN2: Validation Error\")\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.set_ylabel('% Error')\n",
    "ax.legend()\n",
    "fig.savefig('train_cnn2.pdf', transparent=True, bbox_inches='tight', pad_inches=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CNN2-260: training and validation loss per epoch, saved to train_cnn2_260.pdf.\n",
    "# NOTE(review): the curves are the loss values recorded by fit(), although the\n",
    "# axis and legend text call them 'Error'.\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "curves = history_cnn2_260.history\n",
    "ax.plot(history_cnn2_260.epoch, curves['loss'], label=\"CNN2-260: Training Error + Loss\")\n",
    "ax.plot(history_cnn2_260.epoch, curves['val_loss'], label=\"CNN2-260: Validation Error\")\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.set_ylabel('% Error')\n",
    "ax.legend()\n",
    "fig.savefig('train_cnn2_260.pdf', transparent=True, bbox_inches='tight', pad_inches=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Complex: training and validation loss per epoch, saved to train_complex.pdf.\n",
    "# NOTE(review): the curves are the loss values recorded by fit(), although the\n",
    "# axis and legend text call them 'Error'.\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "curves = history_complex.history\n",
    "ax.plot(history_complex.epoch, curves['loss'], label='Complex: Training Error + Loss')\n",
    "ax.plot(history_complex.epoch, curves['val_loss'], label='Complex: Validation Error')\n",
    "ax.set_xlabel('Epoch')\n",
    "ax.set_ylabel('% Error')\n",
    "ax.legend()\n",
    "fig.savefig('train_complex.pdf', transparent=True, bbox_inches='tight', pad_inches=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Confusion Matrices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a function to plot the confusion matrices\n",
    "def plot_confusion_matrix(cm, title='', cmap=plt.cm.Blues, labels=()):\n",
    "    \"\"\"Draw cm as a heat map on the current matplotlib figure.\n",
    "\n",
    "    Args:\n",
    "        cm: 2-D array; the y axis is labelled as the true class and the x axis\n",
    "            as the predicted class.\n",
    "        title: Title drawn above the plot.\n",
    "        cmap: Colormap passed to imshow.\n",
    "        labels: Tick labels used on both axes, in class-index order.\n",
    "    \"\"\"\n",
    "    plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "    plt.title(title)\n",
    "    tick_marks = np.arange(len(labels))\n",
    "    plt.xticks(tick_marks, labels, rotation=45)\n",
    "    plt.yticks(tick_marks, labels)\n",
    "    plt.ylabel('True label')\n",
    "    plt.xlabel('Predicted label')\n",
    "    # Lay out last so the axis labels added above are part of the fit\n",
    "    plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall confusion matrix for CNN2 on X_test (rows = true class, columns =\n",
    "# predicted class), row-normalised before plotting.\n",
    "test_Y_hat = cnn2.predict(X_test, batch_size=batch_size)\n",
    "n_classes = len(classes)\n",
    "true_idx = np.argmax(Y_test, axis=1)\n",
    "pred_idx = np.argmax(test_Y_hat, axis=1)\n",
    "conf = np.zeros([n_classes, n_classes])\n",
    "for j, k in zip(true_idx, pred_idx):\n",
    "    conf[j, k] += 1\n",
    "confnorm = conf / conf.sum(axis=1, keepdims=True)\n",
    "\n",
    "# Correct predictions sit on the diagonal\n",
    "cor = np.trace(conf)\n",
    "ncor = conf.sum() - cor\n",
    "acc = 1.0*cor/(cor+ncor)\n",
    "print(\"Overall Accuracy - CNN2: \", cor / (cor+ncor))\n",
    "plt.figure()\n",
    "plot_confusion_matrix(confnorm, labels=classes)\n",
    "plt.savefig('Confusion_CNN2.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall confusion matrix for CNN2-260 on X_test (rows = true class, columns =\n",
    "# predicted class), row-normalised before plotting.\n",
    "test_Y_hat = cnn2_260.predict(X_test, batch_size=batch_size)\n",
    "n_classes = len(classes)\n",
    "true_idx = np.argmax(Y_test, axis=1)\n",
    "pred_idx = np.argmax(test_Y_hat, axis=1)\n",
    "conf = np.zeros([n_classes, n_classes])\n",
    "for j, k in zip(true_idx, pred_idx):\n",
    "    conf[j, k] += 1\n",
    "confnorm = conf / conf.sum(axis=1, keepdims=True)\n",
    "\n",
    "# Correct predictions sit on the diagonal\n",
    "cor = np.trace(conf)\n",
    "ncor = conf.sum() - cor\n",
    "acc = 1.0*cor/(cor+ncor)\n",
    "print(\"Overall Accuracy - CNN2-260: \", cor / (cor+ncor))\n",
    "plt.figure()\n",
    "plot_confusion_matrix(confnorm, labels=classes)\n",
    "plt.savefig('Confusion_CNN2_260.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overall confusion matrix for the complex network on X_test (rows = true class,\n",
    "# columns = predicted class), row-normalised before plotting.\n",
    "test_Y_hat = complex_CNN.predict(X_test, batch_size=batch_size)\n",
    "n_classes = len(classes)\n",
    "true_idx = np.argmax(Y_test, axis=1)\n",
    "pred_idx = np.argmax(test_Y_hat, axis=1)\n",
    "conf = np.zeros([n_classes, n_classes])\n",
    "for j, k in zip(true_idx, pred_idx):\n",
    "    conf[j, k] += 1\n",
    "confnorm = conf / conf.sum(axis=1, keepdims=True)\n",
    "\n",
    "# Correct predictions sit on the diagonal\n",
    "cor = np.trace(conf)\n",
    "ncor = conf.sum() - cor\n",
    "acc = 1.0*cor/(cor+ncor)\n",
    "print(\"Overall Accuracy - Complex: \", cor / (cor+ncor))\n",
    "plt.figure()\n",
    "plot_confusion_matrix(confnorm, labels=classes)\n",
    "plt.savefig('Confusion_Complex.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Accuracy by SNR (Confusion Matrices @ -20 dB and 18 dB)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Group the held-out examples by SNR so that per-SNR accuracy is measured only\n",
    "# on data the networks were not trained on. Taking every example from Xd here\n",
    "# would also score the half of the data that was used for fitting.\n",
    "# samples_db[i] / truth_labels_db[i] hold the test inputs / one-hot labels whose\n",
    "# SNR is test_snrs[i]; group sizes differ because the split is random.\n",
    "labels_oh = np.eye(len(mods))\n",
    "\n",
    "# SNR of every test example, aligned with the rows of X_test and Y_test\n",
    "test_snr_values = np.array([lbl[idx][1] for idx in test_idx])\n",
    "\n",
    "samples_db = []\n",
    "truth_labels_db = []\n",
    "for snr in test_snrs:\n",
    "    at_this_snr = test_snr_values == snr\n",
    "    samples_db.append(X_test[at_this_snr])\n",
    "    truth_labels_db.append(Y_test[at_this_snr])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of CNN2 at each SNR; the confusion matrix is plotted and saved for\n",
    "# the first and last entries of test_snrs only.\n",
    "n_snrs = len(test_snrs)\n",
    "acc_cnn2 = np.zeros(n_snrs)\n",
    "for s in range(n_snrs):\n",
    "\n",
    "    # examples and one-hot labels at this SNR\n",
    "    test_X_i = samples_db[s]\n",
    "    test_Y_i = truth_labels_db[s]\n",
    "\n",
    "    # estimate classes and count (true, predicted) pairs\n",
    "    test_Y_i_hat = cnn2.predict(test_X_i)\n",
    "    conf = np.zeros([len(mods),len(mods)])\n",
    "    for j, k in zip(np.argmax(test_Y_i, axis=1), np.argmax(test_Y_i_hat, axis=1)):\n",
    "        conf[j,k] += 1\n",
    "\n",
    "    # the row-normalised matrix is only needed for the two plotted SNRs\n",
    "    if s == 0 or s == n_snrs - 1:\n",
    "        confnorm = conf / conf.sum(axis=1, keepdims=True)\n",
    "        plt.figure()\n",
    "        plot_confusion_matrix(confnorm, labels=classes)\n",
    "        plt.savefig('Confusion_CNN2_'+str(s)+'.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)\n",
    "\n",
    "    # correct predictions sit on the diagonal\n",
    "    acc_cnn2[s] = np.trace(conf) / np.sum(conf)\n",
    "print(acc_cnn2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of CNN2-260 at each SNR; the confusion matrix is plotted and saved\n",
    "# for the first and last entries of test_snrs only.\n",
    "n_snrs = len(test_snrs)\n",
    "acc_cnn2_260 = np.zeros(n_snrs)\n",
    "for s in range(n_snrs):\n",
    "\n",
    "    # examples and one-hot labels at this SNR\n",
    "    test_X_i = samples_db[s]\n",
    "    test_Y_i = truth_labels_db[s]\n",
    "\n",
    "    # estimate classes and count (true, predicted) pairs\n",
    "    test_Y_i_hat = cnn2_260.predict(test_X_i)\n",
    "    conf = np.zeros([len(mods),len(mods)])\n",
    "    for j, k in zip(np.argmax(test_Y_i, axis=1), np.argmax(test_Y_i_hat, axis=1)):\n",
    "        conf[j,k] += 1\n",
    "\n",
    "    # the row-normalised matrix is only needed for the two plotted SNRs\n",
    "    if s == 0 or s == n_snrs - 1:\n",
    "        confnorm = conf / conf.sum(axis=1, keepdims=True)\n",
    "        plt.figure()\n",
    "        plot_confusion_matrix(confnorm, labels=classes)\n",
    "        plt.savefig('Confusion_CNN2_260_'+str(s)+'.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)\n",
    "\n",
    "    # correct predictions sit on the diagonal\n",
    "    acc_cnn2_260[s] = np.trace(conf) / np.sum(conf)\n",
    "print(acc_cnn2_260)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of the complex network at each SNR; the confusion matrix is plotted\n",
    "# and saved for the first and last entries of test_snrs only.\n",
    "n_snrs = len(test_snrs)\n",
    "acc_complex = np.zeros(n_snrs)\n",
    "for s in range(n_snrs):\n",
    "\n",
    "    # examples and one-hot labels at this SNR\n",
    "    test_X_i = samples_db[s]\n",
    "    test_Y_i = truth_labels_db[s]\n",
    "\n",
    "    # estimate classes and count (true, predicted) pairs\n",
    "    test_Y_i_hat = complex_CNN.predict(test_X_i)\n",
    "    conf = np.zeros([len(mods),len(mods)])\n",
    "    for j, k in zip(np.argmax(test_Y_i, axis=1), np.argmax(test_Y_i_hat, axis=1)):\n",
    "        conf[j,k] += 1\n",
    "\n",
    "    # the row-normalised matrix is only needed for the two plotted SNRs\n",
    "    if s == 0 or s == n_snrs - 1:\n",
    "        confnorm = conf / conf.sum(axis=1, keepdims=True)\n",
    "        plt.figure()\n",
    "        plot_confusion_matrix(confnorm, labels=classes)\n",
    "        plt.savefig('Confusion_Complex_'+str(s)+'.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)\n",
    "\n",
    "    # correct predictions sit on the diagonal\n",
    "    acc_complex[s] = np.trace(conf) / np.sum(conf)\n",
    "print(acc_complex)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Accuracy Comparison of the Architectures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy vs. SNR for all three networks. The x values are test_snrs, the list\n",
    "# the accuracy arrays are indexed by, rather than a hard-coded range.\n",
    "plt.scatter(test_snrs, acc_cnn2*100, label = \"CNN2\", color = 'coral')\n",
    "plt.scatter(test_snrs, acc_cnn2_260*100, label = \"CNN2-260\", color = 'lightblue')\n",
    "plt.scatter(test_snrs, acc_complex*100, label = 'Complex', color = 'green')\n",
    "plt.grid()\n",
    "plt.xlabel('SNR (dB)')\n",
    "plt.ylabel('Accuracy (%)')\n",
    "plt.legend()\n",
    "plt.savefig('Classification_Accuracy_All.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy vs. SNR: CNN2 against the complex network. The x values are\n",
    "# test_snrs, the list the accuracy arrays are indexed by.\n",
    "plt.scatter(test_snrs, acc_cnn2*100, label = \"CNN2\", color = 'coral')\n",
    "plt.scatter(test_snrs, acc_complex*100, label = 'Complex', color = 'green')\n",
    "plt.grid()\n",
    "plt.xlabel('SNR (dB)')\n",
    "plt.ylabel('Accuracy (%)')\n",
    "plt.legend()\n",
    "plt.savefig('Classification_Accuracy_CNN2.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy vs. SNR: CNN2-260 against the complex network. The x values are\n",
    "# test_snrs, the list the accuracy arrays are indexed by.\n",
    "plt.scatter(test_snrs, acc_cnn2_260*100, label = \"CNN2-260\", color = 'lightblue')\n",
    "plt.scatter(test_snrs, acc_complex*100, label = 'Complex', color = 'green')\n",
    "plt.grid()\n",
    "plt.xlabel('SNR (dB)')\n",
    "plt.ylabel('Accuracy (%)')\n",
    "plt.legend()\n",
    "plt.savefig('Classification_Accuracy_CNN2_260.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative accuracy change of the complex network over each baseline, in\n",
    "# percent of the baseline's accuracy, at every SNR in test_snrs.\n",
    "plt.scatter(test_snrs, (acc_complex-acc_cnn2)/acc_cnn2*100, label = \"CNN2\", color = 'coral')\n",
    "plt.scatter(test_snrs, (acc_complex-acc_cnn2_260)/acc_cnn2_260*100, label = \"CNN2-260\", color = 'lightblue')\n",
    "plt.xlabel('SNR (dB)')\n",
    "plt.ylabel('Classification Improvement (%)')\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.savefig('Classification_Improvement_All.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative accuracy change of the complex network over CNN2, in percent of\n",
    "# CNN2's accuracy, at every SNR in test_snrs.\n",
    "plt.scatter(test_snrs, (acc_complex-acc_cnn2)/acc_cnn2*100, label = \"CNN2\", color = 'coral')\n",
    "plt.xlabel('SNR (dB)')\n",
    "plt.ylabel('Classification Improvement (%)')\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.savefig('Classification_Improvement_CNN2.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative accuracy change of the complex network over CNN2-260, in percent of\n",
    "# CNN2-260's accuracy, at every SNR in test_snrs.\n",
    "plt.scatter(test_snrs, (acc_complex-acc_cnn2_260)/acc_cnn2_260*100, label = \"CNN2-260\", color = 'lightblue')\n",
    "plt.xlabel('SNR (dB)')\n",
    "plt.ylabel('Classification Improvement (%)')\n",
    "plt.grid()\n",
    "plt.legend()\n",
    "plt.savefig('Classification_Improvement_CNN2_260.jpg',transparent = True, bbox_inches = 'tight', pad_inches = 0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
